#!/usr/bin/env python3
#
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This software product is a proprietary product of Mellanox Technologies Ltd.
# (the "Company") and all right, title, and interest in and to the software
# product, including all associated intellectual property rights, are and
# shall remain exclusively with the Company.
#
# This software product is governed by the End User License Agreement
# provided with the software product.
#

import argparse
import sys
import shlex
import json
import socket
import time
import os
import copy

# Version of this command line client; main() appends it to the argparse
# description text.
__version__='v1.3.6'

try:
    from shlex import quote
except ImportError:
    from pipes import quote

class JsonRpcSnapException(Exception):
    """Error raised when a JSON-RPC exchange fails.

    The description is stored in ``message`` and is also handed to the
    ``Exception`` base class, so ``str(exc)`` and ``exc.args`` carry it
    regardless of whether the exception was built positionally or with
    ``message=...``.
    """

    def __init__(self, message):
        # Without chaining to the base initializer, keyword construction
        # leaves ``args`` empty and ``str(exc)`` returns an empty string.
        super(JsonRpcSnapException, self).__init__(message)
        self.message = message

class JsonRpcVirtnetClient(object):
    decoder = json.JSONDecoder()

    def __init__(self, address, port, timeout=125.0):
        self.sock = None
        self._request_id = 0
        self.timeout = timeout
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.sock.connect((address, int(port)))
        except socket.error as ex:
            print("ERR: Can't connect to vhostd: %s" % (ex))
            print("     VDPA is not ready to accept commands")
            sys.exit(1)

    def __json_to_string(self, request):
        return json.dumps(request)

    def send(self, method, params=None):
        self._request_id += 1
        req = {
            'jsonrpc': '2.0',
            'method': method,
            'id': self._request_id
        }
        if params:
            req['params'] = copy.deepcopy(params)

        self.sock.sendall(self.__json_to_string(req).encode("utf-8"))
        return self._request_id

    def __string_to_json(self, request_str):
        try:
            obj, idx = self.decoder.raw_decode(request_str)
            return obj
        except ValueError:
            return None

    def recv(self):
        timeout = self.timeout
        start_time = time.time()
        response = None
        buf = ""

        while not response:
            try:
                timeout = timeout - (time.time() - start_time)
                self.sock.settimeout(timeout)
                buf += self.sock.recv(4096).decode("utf-8")
                response = self.__string_to_json(buf)
            except socket.timeout:
                print("ERR: response timeout")
                sys.exit(1)
            except ValueError:
                continue  # incomplete response; keep buffering

        self.sock.close()

        if not response:
            raise JsonRpcSnapException("Response Timeout")
        return response

    def call(self, method, params={}):
        if params:
            print(params)
        req_id = self.send(method, params)
        response = self.recv()

        if 'error' in response:
            params["method"] = method
            params["req_id"] = req_id
            msg = "\n".join(["request:", "%s" % json.dumps(params, indent=2),
                             "Got JSON-RPC error response",
                             "response:",
                             json.dumps(response['error'], indent=2)])
            raise JsonRpcSnapException(msg)

        return response['result']

def call_rpc_func(args):
    """Invoke the handler stored in ``args.func`` with ``args`` itself.

    The handler's return value is discarded.
    """
    handler = args.func
    handler(args)

def execute_script(parser, client, fd):
    """Execute each non-blank line of ``fd`` as one command line.

    Every line is split with ``shlex``, parsed by ``parser`` and dispatched
    through ``call_rpc_func`` using ``client``. On the first
    ``JsonRpcSnapException`` the commands executed so far are printed (the
    failing one last, marked with ``<<<``) followed by the error message, and
    the process exits with status 1.

    Args:
        parser: The argparse parser used to interpret each line.
        client: RPC client attached to the parsed arguments as ``args.client``.
        fd: Iterable of text lines, e.g. ``sys.stdin``.
    """
    # NOTE(review): JsonRpcVirtnetClient.recv() closes its socket after one
    # response, so a second command on the same client would send on a closed
    # socket - confirm whether multi-line scripts are expected to work.
    executed = []
    for line in fd:
        rpc_call = line.rstrip()
        if not rpc_call.strip():
            continue
        executed.append(rpc_call)
        args = parser.parse_args(shlex.split(rpc_call))
        args.client = client
        try:
            call_rpc_func(args)
        except JsonRpcSnapException as ex:
            print("Exception:")
            print("\n".join(executed).strip() + " <<<")
            print(ex.message)
            # sys.exit() instead of the site-provided exit(), which is meant
            # for interactive use and is absent when site is not loaded.
            sys.exit(1)

def mgmtpf(args):
    """Issue the 'mgmtpf' RPC for the selected PF action and print the result.

    The add and remove actions also send the device as ``dev``; the list
    action sends no device. The result is printed as indented JSON.
    """
    # (flag value, request key, whether the device accompanies the action)
    actions = (
        (args.add_pf, 'add', True),
        (args.remove_pf, 'remove', True),
        (args.list_pf, 'list', False),
    )

    request = {}
    for selected, key, with_device in actions:
        if not selected:
            continue
        request[key] = selected
        if with_device:
            request['dev'] = args.device

    reply = args.client.call('mgmtpf', request)
    print(json.dumps(reply, indent=2))

def mgmtvf(args):
    """Issue the 'vf' RPC for the selected VF action and print the result.

    The request parameters are assembled from whichever action flag is set
    on ``args`` (add, remove, list, info or debug), and the result is printed
    as indented JSON.
    """
    params = {}
    if args.add_vf:
        params['add'] = args.add_vf
        params['vfdev'] = args.pfvfdev
        params['socket_file'] = args.vhost_socket
        params['uuid'] = args.vm_uuid
    if args.remove_vf:
        params['remove'] = args.remove_vf
        params['vfdev'] = args.pfvfdev
    if args.list_vf:
        # For the list action the positional device names the PF.
        params['list'] = args.list_vf
        params['mgmtpf'] = args.pfvfdev
    if args.info_vf:
        params['info'] = args.info_vf
        params['vfdev'] = args.pfvfdev
    if args.debug_vf:
        params['vfdev'] = args.pfvfdev
        params['debug'] = args.debug_vf
        if args.test_operation:
            params['test_operation'] = args.test_operation
        if args.test_size_mode:
            # The -b option is parsed as an int (1=bit, 2=byte). Comparing
            # only against the string "byte" never matched, so byte mode was
            # always sent as 1; "byte" stays accepted for compatibility.
            params['size_mode'] = 2 if args.test_size_mode in (2, "byte") else 1

    result = args.client.call('vf', params)
    print(json.dumps(result, indent=2))

def int_hex(x):
    """Parse the string ``x`` as a base-16 integer and return it."""
    return int(x, base=16)

def version(args):
    """Issue the 'version' RPC with no parameters and print the result as
    indented JSON."""
    reply = args.client.call('version', {})
    print(json.dumps(reply, indent=2))

def _usage_error(parser, message):
    """Print ``message`` and the parser's usage line, then exit with status 1."""
    print(message)
    parser.print_usage()
    sys.exit(1)


def _validate_mgmtpf_args(parser, args):
    """Exit with a usage error unless the 'mgmtpf' options form a valid action."""
    if (args.add_pf or args.remove_pf) and not args.device:
        _usage_error(parser, "Error: No device specified for add/remove action.")
    if not (args.add_pf or args.remove_pf or args.list_pf):
        _usage_error(parser, "Error: invalid action.")


def _validate_vf_args(parser, args):
    """Exit with a usage error unless the 'vf' options form a valid action."""
    # Every vf action needs the positional device; only the wording differs.
    missing_device_messages = (
        (args.add_vf, "Error: No vf device specified for add vf action."),
        (args.remove_vf, "Error: No vf device specified for remove vf action."),
        (args.list_vf, "Error: No pf device specified for list vf action."),
        (args.info_vf, "Error: No vf device specified for info vf action."),
        (args.debug_vf, "Error: No vf device specified for test vf action."),
    )
    for selected, message in missing_device_messages:
        if selected and not args.pfvfdev:
            _usage_error(parser, message)

    if args.debug_vf:
        operation = args.test_operation
        # Documented operations are 1..6; reject zero and negatives as well.
        if operation is None or not 1 <= operation <= 6:
            _usage_error(parser, "Error: invalid debug operation.")
        # Only start_logging (4) and stop_logging (5) use the size mode.
        if operation in (4, 5) and args.test_size_mode not in (1, 2):
            _usage_error(parser, "Error: invalid test start_logging parameters.")

    if not any(selected for selected, _ in missing_device_messages):
        _usage_error(parser, "Error: invalid action.")


def main():
    """Entry point of the vhostd-rpc command line interface.

    Builds the argument parser, validates the options of the chosen
    sub-command and executes it against the daemon at 127.0.0.1:12190. When
    no sub-command is given, commands are read line by line from stdin if it
    is not a terminal; otherwise the help text is printed and the process
    exits with status 1.
    """
    server_addr='127.0.0.1'
    server_port='12190'
    timeout=125.0

    parser = argparse.ArgumentParser(
        description='Nvidia vhostd-rpc command line interface ' + __version__)
    subparsers = parser.add_subparsers(help='** Use -h for sub-command usage',
                                       dest='called_rpc_name')

    # version
    p = subparsers.add_parser('version', help='show vhostd version info')
    p.set_defaults(func=version)

    # mgmtpf
    p = subparsers.add_parser('mgmtpf', help='Management PF device')
    group = p.add_mutually_exclusive_group()
    group.add_argument('-a', '--add', action='store_true', dest='add_pf', help="add a pci device")
    group.add_argument('-r', '--remove', action='store_true', dest='remove_pf', help="remove a pci device")
    group.add_argument('-l', '--list', action='store_true',  dest='list_pf', help="list all PF devices")
    p.add_argument(
        'device',
        metavar='DEVICE',
        nargs="?",
        help="""
    Device specified as PCI "domain:bus:slot.func" syntax or "bus:slot.func" syntax.
    For device add/remove to drivers, they may be referred to by interface name.
    """)
    p.set_defaults(func=mgmtpf)

    # vf
    p = subparsers.add_parser('vf', help='Management VF device')
    group = p.add_mutually_exclusive_group()
    group.add_argument('-a', '--add', action='store_true', dest='add_vf', help="add a pci device")
    group.add_argument('-r', '--remove', action='store_true', dest='remove_vf', help="remove a pci device")
    group.add_argument('-l', '--list', action='store_true', dest='list_vf', help="list all VF devices of PF device")
    group.add_argument('-i', '--info', action='store_true', dest='info_vf', help="show specified VF device information")
    group.add_argument('-d', '--debug', action='store_true', dest='debug_vf', help="test VF device debug")
    p.add_argument(
        'pfvfdev',
        metavar='DEVICE',
        nargs="?",
        help="""
    Device specified as PCI "domain:bus:slot.func" syntax or "bus:slot.func" syntax.
    For device add/remove to drivers, they may be referred to by interface name.
    """)
    p.add_argument('-o', metavar='test_operation', dest='test_operation', type=int,
                    help='Device test operation: 1=running; 2=quiesced; 3=freeze; 4=start_logging; 5=stop_logging; 6=get_lm_status')
    p.add_argument('-b', metavar='test_size_mode', dest='test_size_mode', type=int,
                    help='Device test mode: 1=bit; 2=byte', default=2)
    p.add_argument('-v', metavar='vhost_socket', dest='vhost_socket', type=str,
                    help='Vhost socket file name')
    p.add_argument('-u', metavar='vm_uuid', dest='vm_uuid', type=str,
                    help='Virtual machine UUID')
    p.set_defaults(func=mgmtvf)

    args = parser.parse_args()
    if args.called_rpc_name == "mgmtpf":
        _validate_mgmtpf_args(parser, args)
    if args.called_rpc_name == "vf":
        _validate_vf_args(parser, args)

    if hasattr(args, 'func'):
        args.client = JsonRpcVirtnetClient(server_addr, server_port, timeout)
        try:
            call_rpc_func(args)
        except JsonRpcSnapException as ex:
            print(ex)
            sys.exit(1)
    elif sys.stdin.isatty():
        # No arguments and no data piped through stdin: showing the help
        # needs no daemon, so no connection is attempted on this path.
        parser.print_help()
        sys.exit(1)
    else:
        args.client = JsonRpcVirtnetClient(server_addr, server_port, timeout)
        execute_script(parser, args.client, sys.stdin)

# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
